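# Shape note: the ms_* rotary dumps are (1, 64, 4096, 128); transposing them with
# axes (2, 0, 1, 3) gives (4096, 1, 64, 128), the layout of the pt_* reference dumps.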

from useful_tools import compare_acc

print("start to compare diff...")

# Directory holding the reference (pt_*) dumps; every reference file below lives here.
PT_NUMPY_DIR = "/lgsl_data1/huawei_wz/Megatron-FLM-huawei/numpys"


def _compare_transposed(ms_path, pt_name, save_name_prefix):
    """Load an ms/pt array pair, permute the ms array into the pt layout, and compare.

    The ms array is permuted with axes (2, 0, 1, 3), e.g. a (1, 64, 4096, 128)
    array becomes (4096, 1, 64, 128).

    Args:
        ms_path: path of the ms-side ``.npy`` dump.
        pt_name: file name of the reference dump inside ``PT_NUMPY_DIR``.
        save_name_prefix: forwarded unchanged to ``compare_acc.compare``.

    Returns:
        ``(ms_array, pt_array, ms_array_transposed)`` so the script keeps the
        same module-level names available for interactive inspection.
    """
    ms_arr, pt_arr = compare_acc.load_npy(ms_path, f"{PT_NUMPY_DIR}/{pt_name}")
    ms_arr_new = ms_arr.transpose(2, 0, 1, 3)
    compare_acc.compare(ms_arr_new, pt_arr, save_name_prefix=save_name_prefix)
    return ms_arr, pt_arr, ms_arr_new


ms_query, pt_query, ms_query_new = _compare_transposed(
    "ms_query.npy", "query_after_rotary.npy", "ms_rotory_query")

ms_key, pt_key, ms_key_new = _compare_transposed(
    "ms_key.npy", "key_after_rotary.npy", "ms_rotory_key")

# NOTE(review): the value dump uses an "ms0_" prefix unlike the query/key dumps
# above -- confirm this is the intended file.
ms_value, pt_value, ms_value_new = _compare_transposed(
    "ms0_value.npy", "value_after_rotary.npy", "ms_rotory_value")

# Compare ms0_query (presumably the query before rotary embedding) with the first
# slice of the fused projection. It gets its own name so the rotary `ms_query`
# loaded above is not overwritten. It is compared without a transpose, so it is
# presumably already in (4096, 1, 64, 128) layout -- TODO confirm.
ms0_query, mixed_x_pt = compare_acc.load_npy("ms0_query.npy", f"{PT_NUMPY_DIR}/mixed_x_layer.npy")
# Split the fused tensor as (seq, batch, heads, 3, head_dim) and take index 0 on
# axis 3 as the query. An interleaved (..., 128, 3) layout was tried previously.
# NOTE(review): confirm this packing matches how the reference model splits mixed_x_layer.
mixed_x_pt = mixed_x_pt.reshape(4096, 1, 64, 3, 128)
compare_acc.compare(ms0_query, mixed_x_pt[:, :, :, 0, :])
